from visual import *
from random import random

from mpmath import *

class NumericLattice:
    """Classical spin lattice of nSpinsX x nSpinsY sites with wrap-around bonds.

    Spins live in ``spinArray[x][y]``. Neighbouring spins are coupled along x
    through ``couplingXArray`` and along y through ``couplingYArray``. These are
    dense chain matrices built by ``setCouplings``, including a wrap-around bond
    in each direction. ``timeEvolve`` integrates dS/dt = -S x H with a
    fourth-order Runge-Kutta step. H is the coupling-weighted sum of spins
    computed by ``getTorques``.

    Storage layout: ``__init__`` fills every per-site array with one-element
    lists ``[vector]``. ``setState``, ``setSolitonState`` and
    ``randomizeState`` store the vector itself, and ``timeEvolve`` replaces
    ``spinArray`` with the result of array arithmetic. The energy/sum readers
    accept either layout.

    NOTE: the module does ``from mpmath import *`` after ``from visual import *``,
    so names such as ``sin``, ``cos``, ``sqrt``, ``pi``, ``norm`` and ``zeros``
    resolve to mpmath's versions inside this class. Methods that need
    NumPy/float semantics import what they need locally.
    """

    def __init__(self, nSpinsX, nSpinsY, stripeSpacingX, stripeSpacingY):
        """Allocate an nSpinsX x nSpinsY lattice with all vectors zeroed.

        Args:
            nSpinsX, nSpinsY: number of sites along x and y.
            stripeSpacingX, stripeSpacingY: unit-cell size along x and y.
                Sites map into the cell via ``x % stripeSpacingX`` and
                ``y % stripeSpacingY``.
        """
        self.nSpinsX = nSpinsX
        self.nSpinsY = nSpinsY

        self.stripeSpacingX = stripeSpacingX
        self.stripeSpacingY = stripeSpacingY

        # Default couplings; setCouplings replaces them.
        self.Ja = -1.0
        self.Jbx = -1.0
        self.Jby = -1.0
        self.Jad = -1.0

        self.phiScale = 1.0
        self.temp = 0.001

        # Per-unit-cell tables, indexed [x % stripeSpacingX][y % stripeSpacingY].
        self.zArray = self._zeroGrid(self.stripeSpacingX, self.stripeSpacingY)
        self.phaseArray = self._zeroGrid(self.stripeSpacingX, self.stripeSpacingY)
        self.phiArray = self._zeroGrid(self.stripeSpacingX, self.stripeSpacingY)

        # Chain coupling matrices; all zero until setCouplings runs.
        self.couplingXArray = self._zeroGrid(self.nSpinsX, self.nSpinsX)
        self.couplingYArray = self._zeroGrid(self.nSpinsY, self.nSpinsY)

        # Per-site vector fields; each entry starts as a one-element list [vector].
        self.spinArray = self._vectorGrid(self.nSpinsX, self.nSpinsY)
        self.meanFieldArray = self._vectorGrid(self.nSpinsX, self.nSpinsY)
        self.torqueArray = self._vectorGrid(self.nSpinsX, self.nSpinsY)
        self.tempArray = self._vectorGrid(self.nSpinsX, self.nSpinsY)
        self.torqueTempArray = self._vectorGrid(self.nSpinsX, self.nSpinsY)

    def setState(self, k, zArray, phaseArray, phiArray, phiScale, phiRatio, magS2):
        """Set every spin from Jacobi elliptic functions of the phase k.r.

        At site (x, y) the phase is ``k.x*x/nSpinsX + k.y*y/nSpinsY``. It is
        scaled by 2*K(m)/pi, so a phase of 2*pi spans 4*K(m), the real period
        of sn and cn. The spin is
        ``z * (alpha*sn(u, m), beta*cn(u, m), gamma*dn(u, m))``, where z is
        read from zArray at the site's unit-cell position.

        Args:
            k: wave vector (only .x and .y are read).
            zArray: per-unit-cell amplitude table; stored on self.
            phaseArray, phiArray, phiScale, phiRatio, magS2: stored on self;
                not otherwise used by this method.
        """
        import math

        self.magS2 = magS2
        self.phiScale = phiScale
        self.phiRatio = phiRatio

        self.zArray = zArray
        self.phaseArray = phaseArray
        self.phiArray = phiArray

        sn = ellipfun('sn')
        cn = ellipfun('cn')
        dn = ellipfun('dn')

        # None of these depend on the site, so evaluate the (expensive)
        # elliptic functions once rather than once per spin.
        # NOTE(review): alpha/beta/gamma follow the author's ansatz; their
        # derivation is not shown here.
        m = .7
        a = k.x
        cnA = cn(a, m)
        dnA = dn(a, m)
        alpha = 1/sqrt(1+cnA)
        beta = sqrt(dnA)/sqrt(dnA+cnA)
        gamma = sqrt((-1+sqrt(1+cnA))*(1+sqrt(1-cnA)))/sqrt(cnA+dnA)
        phaseToArgument = ellipk(m)/math.pi*2.0

        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                z = self.zArray[x % self.stripeSpacingX][y % self.stripeSpacingY]
                phase = (k.x*x)/self.nSpinsX + (k.y*y)/self.nSpinsY
                u = phase*phaseToArgument
                self.spinArray[x][y] = vector(z*alpha*sn(u, m),
                                              z*beta*cn(u, m),
                                              z*gamma*dn(u, m))

    def setSolitonState(self, k, zArray, phaseArray, phiArray, phiScale, phiRatio, magS2):
        """Set a staggered profile that depends on the distance from the x centre.

        With u = (x - nSpinsX/2) / 8, each spin is
        ``sign * (1/sqrt(1+u**2), 0, sqrt(1 - 1/(1+u**2)))``, a unit vector in
        the x-z plane. ``sign`` is -1 when x+y is even and +1 when it is odd
        (checkerboard).

        Args:
            zArray, phaseArray, phiArray, phiScale, phiRatio, magS2: stored on
                self; not otherwise used by this method.
            k: accepted for signature parity with setState; unused.
        """
        self.magS2 = magS2
        self.phiScale = phiScale
        self.phiRatio = phiRatio

        self.zArray = zArray
        self.phaseArray = phaseArray
        self.phiArray = phiArray

        la = 8.0                        # profile width, in lattice sites
        center = self.nSpinsX/2.0

        for x in range(self.nSpinsX):
            # The profile depends only on x; the checkerboard sign on x+y.
            u2 = (x-center)**2/la**2
            inPlane = 1/sqrt(1+u2)
            outOfPlane = sqrt(1-1/(1+u2))
            for y in range(self.nSpinsY):
                sign = -1+2*((x+y) % 2)
                self.spinArray[x][y] = vector(sign*inPlane, 0, sign*outOfPlane)

    def setCouplings(self, Ja, Jbx, Jby, Jad):
        """Store the couplings and rebuild both chain coupling matrices.

        Args:
            Ja: coupling between nearest neighbours inside a unit cell.
            Jbx: coupling across a unit-cell boundary along x, also used for
                the wrap-around bond between the first and last x site.
            Jby: the same as Jbx, for the y direction.
            Jad: stored on self; not used by any coupling built here.
        """
        self.Ja = Ja
        self.Jbx = Jbx
        self.Jby = Jby
        self.Jad = Jad

        self.couplingXArray = self._chainCouplings(self.nSpinsX, self.stripeSpacingX, self.Ja, self.Jbx)
        self.couplingYArray = self._chainCouplings(self.nSpinsY, self.stripeSpacingY, self.Ja, self.Jby)

    def getTorques(self, spinArray, t):
        """Fill meanFieldArray and torqueArray for the configuration spinArray.

        For each site:
        H[x][y] = sum_i couplingXArray[x][i]*S[i][y] + sum_j couplingYArray[y][j]*S[x][j],
        and the torque is -S[x][y] x H[x][y].

        Args:
            spinArray: spin grid to evaluate. It may differ from self.spinArray,
                as in the Runge-Kutta stages of timeEvolve. Entries must
                support multiplication by a float.
            t: current time; accepted for time-dependent variants, unused.
        """
        for x in xrange(self.nSpinsX):
            couplingRowX = self.couplingXArray[x]
            for y in xrange(self.nSpinsY):
                couplingRowY = self.couplingYArray[y]
                fieldsum = vector(0.0, 0.0, 0.0)

                for i in xrange(self.nSpinsX):
                    fieldsum += couplingRowX[i]*spinArray[i][y]

                for j in xrange(self.nSpinsY):
                    fieldsum += couplingRowY[j]*spinArray[x][j]

                self.meanFieldArray[x][y] = fieldsum

        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                self.torqueArray[x][y] = -cross(spinArray[x][y], self.meanFieldArray[x][y])

    def timeEvolve(self, t, dt, temp, EBu, EBd):
        """Advance spinArray by one fourth-order Runge-Kutta step of length dt.

        k0 adds dt * (-S x tempArray[x][y][0]). Nothing in this class writes
        tempArray after __init__ (zero vectors), so k0 is zero unless a caller
        fills it. Spins are not renormalised after the step, so |S| can drift.

        Afterwards meanFieldArray/torqueArray hold the values for the last RK
        stage (spinArray + k3), not for the updated state.

        Args:
            t: time at the start of the step (forwarded to getTorques).
            dt: step size.
            temp: stored on self.temp; not otherwise used here.
            EBu, EBd: accepted but unused.
        """
        self.temp = temp

        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                self.torqueTempArray[x][y] = -cross(self.spinArray[x][y], self.tempArray[x][y][0])

        k0 = multiply(dt, self.torqueTempArray)

        self.getTorques(self.spinArray, t)
        k1 = multiply(dt, self.torqueArray)
        self.getTorques(self.spinArray + k1/2.0, t + dt/2.0)
        k2 = multiply(dt, self.torqueArray)
        self.getTorques(self.spinArray + k2/2.0, t + dt/2.0)
        k3 = multiply(dt, self.torqueArray)
        self.getTorques(self.spinArray + k3, t + dt)
        k4 = multiply(dt, self.torqueArray)

        self.spinArray = self.spinArray + k0 + k1/6.0 + k2/3.0 + k3/3.0 + k4/6.0

    # Data for Export Methods

    def returnState(self):
        """Return the spin grid (the live object, not a copy)."""
        return self.spinArray

    def returnPhaseArray(self):
        """Return the per-unit-cell phase table."""
        return self.phaseArray

    def returnPhiArray(self):
        """Return the per-unit-cell phi table."""
        return self.phiArray

    def returnzArray(self):
        """Return the per-unit-cell z table."""
        return self.zArray

    def returnMeanFieldArray(self):
        """Return the mean fields left by the most recent getTorques call."""
        return self.meanFieldArray

    def returnTorqueArray(self):
        """Return the torques left by the most recent getTorques call."""
        return self.torqueArray

    def returnPhiScale(self):
        """Return the stored phi scale."""
        return self.phiScale

    def returnSpinSum(self):
        """Return the vector sum of all spins."""
        sx = sy = sz = 0.0
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                spin = self._siteVector(self.spinArray[x][y])
                sx += spin[0]
                sy += spin[1]
                sz += spin[2]
        return vector(sx, sy, sz)

    def returnEnergies(self):
        """Return an nSpinsX x nSpinsY float array of site energies -S.H.

        Uses meanFieldArray as left by the most recent getTorques call.
        """
        import numpy

        energyArray = numpy.zeros((self.nSpinsX, self.nSpinsY))
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                spin = self._siteVector(self.spinArray[x][y])
                field = self._siteVector(self.meanFieldArray[x][y])
                energyArray[x][y] = -self._dot3(spin, field)
        return energyArray

    def returnBondXEnergies(self):
        """Return -S(x).S(x-1) for every x bond as an (nSpinsX+1) x nSpinsY array.

        Row x pairs site x with site x-1 (periodic), so rows 0 and nSpinsX
        describe the same wrap-around bond. Couplings are not applied.
        """
        import numpy

        spinArray = self.returnState()

        energyArrayX = numpy.zeros((self.nSpinsX+1, self.nSpinsY))
        for x in range(self.nSpinsX+1):
            for y in range(self.nSpinsY):
                here = self._siteVector(spinArray[x % self.nSpinsX][y])
                left = self._siteVector(spinArray[(x-1) % self.nSpinsX][y])
                energyArrayX[x][y] = -self._dot3(here, left)
        return energyArrayX

    def returnBondYEnergies(self):
        """Return -S(y).S(y-1) for every y bond as an nSpinsX x (nSpinsY+1) array.

        Column y pairs site y with site y-1 (periodic), so columns 0 and
        nSpinsY describe the same wrap-around bond. Couplings are not applied.
        """
        import numpy

        spinArray = self.returnState()

        energyArrayY = numpy.zeros((self.nSpinsX, self.nSpinsY+1))
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY+1):
                here = self._siteVector(spinArray[x][y % self.nSpinsY])
                below = self._siteVector(spinArray[x][(y-1) % self.nSpinsY])
                energyArrayY[x][y] = -self._dot3(here, below)
        return energyArrayY

    def returnTotalEnergy(self):
        """Return half the sum of the site energies.

        Each bond contributes to the mean field at both of its ends, hence /2.
        """
        energyArray = self.returnEnergies()
        totEnergy = 0.0
        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                totEnergy = totEnergy + energyArray[x][y]
        return totEnergy/2.0

    def returnCouplingX(self):
        """Return the x chain coupling matrix."""
        return self.couplingXArray

    def returnCouplingY(self):
        """Return the y chain coupling matrix."""
        return self.couplingYArray

    def returnStripeSpacingX(self):
        """Return the unit-cell size along x."""
        return self.stripeSpacingX

    def returnStripeSpacingY(self):
        """Return the unit-cell size along y."""
        return self.stripeSpacingY

    # Perturbation Methods

    def randomizeState(self):
        """Point every spin in an independent, uniformly random direction.

        The azimuth is uniform in [0, 2*pi). The polar angle is
        acos(2v - 1) with v uniform in [0, 1), which gives a uniform
        distribution on the unit sphere. The resulting vector already has
        unit length.
        """
        import math
        import random as _random

        for x in range(self.nSpinsX):
            for y in range(self.nSpinsY):
                azimuth = 2.0*math.pi*_random.random()
                polar = math.acos(2.0*_random.random()-1.0)
                sinPolar = math.sin(polar)
                self.spinArray[x][y] = vector(sinPolar*math.cos(azimuth),
                                              sinPolar*math.sin(azimuth),
                                              math.cos(polar))

    # Internal helpers

    @staticmethod
    def _zeroGrid(rows, cols):
        """Return a rows x cols list of distinct lists filled with 0.0."""
        return [[0.0] * cols for _ in range(rows)]

    @staticmethod
    def _vectorGrid(rows, cols):
        """Return a rows x cols grid of fresh one-element lists [vector(0, 0, 0)]."""
        return [[[vector(0.0, 0.0, 0.0)] for _ in range(cols)] for _ in range(rows)]

    @staticmethod
    def _siteVector(entry):
        """Return the 3-vector stored at a site, unwrapping a one-element [vector] list."""
        try:
            size = len(entry)
        except TypeError:
            return entry
        if size == 1:
            return entry[0]
        return entry

    @staticmethod
    def _dot3(a, b):
        """Return the dot product of two 3-component sequences."""
        return a[0]*b[0] + a[1]*b[1] + a[2]*b[2]

    @staticmethod
    def _chainCouplings(n, spacing, Ja, Jb):
        """Build the n x n coupling matrix of a periodic chain of unit cells.

        Nearest neighbours get Ja, or Jb when the bond crosses a unit-cell
        boundary (one site is first in its cell, the other last in the
        previous cell). The wrap-around pair (0, n-1) additionally gets Jb.
        For n == 2 a pair therefore collects both contributions. For n == 1
        the diagonal gets Jb.
        """
        coupling = [[0.0] * n for _ in range(n)]
        for i in range(n):
            for j in range(n):
                if i - j == 1:
                    crossesCell = i % spacing == 0 and j % spacing == spacing - 1
                    coupling[i][j] += Jb if crossesCell else Ja
                if j - i == 1:
                    crossesCell = j % spacing == 0 and i % spacing == spacing - 1
                    coupling[i][j] += Jb if crossesCell else Ja
                if abs(i - j) == n - 1:
                    coupling[i][j] += Jb
        return coupling





























##        self.temp = 0.00
##        self.EBu = 0.00
##        self.EBd = 0.00
##        self.EBaxis = vector(0.0,0.0,1.0)


##        w = (30.0/10.0)*2.0*pi
##        A = .7
##        arg = A*sin(w*t)
##        EField = self.EBu*vector(0.0,0.0,cos(w*t))
##        appliedFieldArray = [[EField],[EField],[EField]]

                
##                if self.temp != 0.0:                            # adds a random torque <--> kick to each spin, with a scale set by temp.
##                    randomU = random()                          # perturbed by unit ball
##                    randomV = random()                          # we produce a uniform distribution on the unit ball
##                    theta = 2.0*3.1416*randomU
##                    phi = arccos(2.0*randomV-1.0)
##                    kick = self.temp*norm(vector(sin(theta)*cos(phi),cos(theta)*sin(phi),cos(phi)))
##                    
##                    self.meanFieldArray[x][y] += kick        # this kick is constant in magnitude but random in direction

##                if self.EBu != 0.0 or self.EBd != 0.0:
##                    self.meanFieldArray[x][y] += appliedFieldArray[x][y]



    

##    def setStateManually(self, k, baseSigma, zArray, sigmaArray, phaseArray):
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                xModXSpacing = x%self.stripeSpacingX
##                yModYSpacing = y%self.stripeSpacingY
##                xR = x - xModXSpacing
##                yR = y - yModYSpacing
##
##                z = zArray[xModXSpacing][yModYSpacing]
##
##                tempSigma = baseSigma*sigmaArray[xModXSpacing][yModYSpacing]
##                tempAngle = phaseArray[xModXSpacing][yModYSpacing] + k.x*xR + k.y*yR
##                
##                tempSpinX = tempSigma*cos(tempAngle)
##                tempSpinY = tempSigma*sin(tempAngle)
##                tempSpinZ = z*(1.0-tempSigma**2)**.5
##               
##                self.spinArray[x][y] = (tempSpinX,tempSpinY,tempSpinZ)
##
##
##

##
##        # Output Arrays
##        self.energyArray = zeros((nSpinsX,nSpinsY),Float32)                 # scalars
##        self.maxEnergyArray = -4.0*ones((nSpinsX,nSpinsY),Float32)
##        self.minEnergyArray = zeros((nSpinsX,nSpinsY),Float32)
##
##        self.TrackSpins = TrackSpins




##    def setStateFromH(self, k, baseSigma, n):
##        self.Ham.makeHamiltonian(k)
##
##        if n == -1:     state = self.Ham.getGroundState()
##        else:           state = self.Ham.getState(n)
##
##        xComponentUCArray = state.real
##        yComponentUCArray = state.imaginary
##
##        for x in range(self.stripeSpacingX):
##            for y in range(self.stripeSpacingY):
##                z = (-1.0)**(x%2 + y%2)
##                self.phaseArray[x][y] = k.x*x + k.y*y + (1-z)*pi/2
##
##        for x in range(self.nSpinsX):
##            for y in range(self.nSpinsY):
##                xModXSpacing = x%self.stripeSpacingX
##                yModYSpacing = y%self.stripeSpacingY
##                xR = x - xModXSpacing
##                yR = y - yModYSpacing
##
##                z = (-1.0)**(xModXSpacing%2 + yModYSpacing%2)
##
##                self.sigmaArray[xModXSpacing][yModYSpacing] = baseSigma*(xComponentUCArray[xModXSpacing][yModYSpacing]**2 + yComponentUCArray[xModXSpacing][yModYSpacing]**2)**0.5
##                self.phaseArray[xModXSpacing][yModYSpacing] = k.x*xModXSpacing + k.y*yModYSpacing + (1-z)*pi/2
##
##                tempSigma = self.sigmaArray[xModXSpacing][yModYSpacing]
##                tempAngle = self.phaseArray[xModXSpacing][yModYSpacing] + k.x*xR + k.y*yR
##
##                tempSpinX = baseSigma*(xComponentUCArray[xModXSpacing][yModYSpacing]*cos(tempAngle) - yComponentUCArray[xModXSpacing][yModYSpacing]*sin(tempAngle))
####                tempSpinX = baseSigma*xComponentUCArray[xModXSpacing][yModYSpacing]
##                tempSpinY = baseSigma*(xComponentUCArray[xModXSpacing][yModYSpacing]*sin(tempAngle) + yComponentUCArray[xModXSpacing][yModYSpacing]*cos(tempAngle))
####                tempSpinY = baseSigma*yComponentUCArray[xModXSpacing][yModYSpacing]
##                tempSpinZ = z*(1.0-tempSigma**2)**.5
##               
##                self.spinArray[x][y] = (tempSpinX,tempSpinY,tempSpinZ)
##
##
